{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from PIL import Image\n",
    "import argparse\n",
    "import cv2\n",
    "import torch\n",
    "import numpy as np\n",
    "import glob\n",
    "from torch.cuda import amp\n",
    "from tqdm import tqdm\n",
    "import sys\n",
    "from controlnet_aux.processor import Processor as controlet_processor\n",
    "from anime_segmentation.train import AnimeSegmentation, net_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mask(model, input_img, use_amp=True, s=640):\n",
    "    \"\"\"Run ``model`` on ``input_img`` and return its prediction at the input size.\n",
    "\n",
    "    The image is divided by 255, resized so that its longer side equals ``s``,\n",
    "    centred on a zero-filled ``s`` x ``s`` x 3 canvas and passed through the\n",
    "    model. The prediction is cropped back to the resized area and resized to\n",
    "    the height and width of ``input_img``.\n",
    "\n",
    "    Args:\n",
    "        model: Callable network exposing a ``device`` attribute.\n",
    "        input_img: Array of shape (H, W, 3); values are divided by 255.\n",
    "        use_amp: When true, the forward pass runs under ``amp.autocast()``\n",
    "            and the result is cast back to float32.\n",
    "        s: Side length of the square canvas fed to the model.\n",
    "\n",
    "    Returns:\n",
    "        The resized prediction with a trailing axis appended\n",
    "        (presumably (H, W, 1) for a single-channel output -- TODO confirm).\n",
    "    \"\"\"\n",
    "    scaled = (input_img / 255).astype(np.float32)\n",
    "    h0, w0 = scaled.shape[:-1]\n",
    "    # Fit the longer side to s while keeping the aspect ratio.\n",
    "    if h0 > w0:\n",
    "        h, w = s, int(s * w0 / h0)\n",
    "    else:\n",
    "        h, w = int(s * h0 / w0), s\n",
    "    top, left = (s - h) // 2, (s - w) // 2\n",
    "    canvas = np.zeros([s, s, 3], dtype=np.float32)\n",
    "    canvas[top : top + h, left : left + w] = cv2.resize(scaled, (w, h))\n",
    "    # HWC -> CHW, then add the batch axis expected by the model.\n",
    "    batch = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]\n",
    "    tensor = torch.from_numpy(batch).type(torch.FloatTensor).to(model.device)\n",
    "    with torch.no_grad():\n",
    "        if use_amp:\n",
    "            with amp.autocast():\n",
    "                pred = model(tensor)\n",
    "            pred = pred.to(dtype=torch.float32)\n",
    "        else:\n",
    "            pred = model(tensor)\n",
    "        pred = pred.cpu().numpy()[0]\n",
    "    pred = np.transpose(pred, (1, 2, 0))\n",
    "    # Drop the padding, then go back to the original resolution.\n",
    "    pred = pred[top : top + h, left : left + w]\n",
    "    return cv2.resize(pred, (w0, h0))[:, :, np.newaxis]\n",
    "\n",
    "# Load the \"isnet_is\" segmentation checkpoint, switch to eval mode and move it to CUDA.\n",
    "checkpoint_path = \"/home/aihao/workspace/DeepLearningContent/models/anime-seg/isnetis.ckpt\"\n",
    "model = AnimeSegmentation.try_load(\"isnet_is\", checkpoint_path, \"cuda\", img_size=1024)\n",
    "model.eval()\n",
    "model.to(\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every \"origin\" image, predict a mask with get_mask and write an RGBA\n",
    "# \"capture\" image next to it. Existing capture files are never overwritten.\n",
    "for validation_path in os.listdir(\"../validation_images\"):\n",
    "    validation_path = os.path.join(\"../validation_images\", validation_path)\n",
    "    # A stray file next to the sample folders would make the inner listdir raise.\n",
    "    if not os.path.isdir(validation_path):\n",
    "        continue\n",
    "    for image_path in os.listdir(validation_path):\n",
    "        if \"origin\" not in image_path:\n",
    "            continue\n",
    "        src_path = os.path.join(validation_path, image_path)\n",
    "        capture_path = os.path.join(\n",
    "            validation_path, image_path.replace(\"origin\", \"capture\")\n",
    "        )\n",
    "        # Check before reading and running the model: the result would be\n",
    "        # thrown away anyway when the capture already exists.\n",
    "        if os.path.exists(capture_path):\n",
    "            continue\n",
    "        bgr = cv2.imread(src_path, cv2.IMREAD_COLOR)\n",
    "        # cv2.imread returns None instead of raising when it cannot read a file.\n",
    "        if bgr is None:\n",
    "            print(f\"skipping unreadable image: {src_path}\")\n",
    "            continue\n",
    "        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)\n",
    "        mask = get_mask(\n",
    "            model, img, use_amp=False, s=min(1024, min(img.shape[0], img.shape[1]))\n",
    "        )\n",
    "        # Colour channels are weighted by the mask; the mask scaled by 255 is\n",
    "        # appended as the fourth (alpha) channel.\n",
    "        img = np.concatenate((mask * img + 1 - mask, mask * 255), axis=2).astype(\n",
    "            np.uint8\n",
    "        )\n",
    "        img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)\n",
    "        cv2.imwrite(capture_path, img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every \"2_capture\" image, write an adaptive-threshold \"line\" image next\n",
    "# to it. Existing line files are never overwritten.\n",
    "for validation_path in os.listdir(\"../validation_images\"):\n",
    "    validation_path = os.path.join(\"../validation_images\", validation_path)\n",
    "    # A stray file next to the sample folders would make the inner listdir raise.\n",
    "    if not os.path.isdir(validation_path):\n",
    "        continue\n",
    "    for image_path in os.listdir(validation_path):\n",
    "        if \"2_capture\" not in image_path:\n",
    "            continue\n",
    "        src_path = os.path.join(validation_path, image_path)\n",
    "        line_path = os.path.join(\n",
    "            validation_path, image_path.replace(\"capture\", \"line\")\n",
    "        )\n",
    "        # Check before reading and thresholding: the result would be thrown\n",
    "        # away anyway when the line image already exists.\n",
    "        if os.path.exists(line_path):\n",
    "            continue\n",
    "        bgr = cv2.imread(src_path, cv2.IMREAD_COLOR)\n",
    "        # cv2.imread returns None instead of raising when it cannot read a file.\n",
    "        if bgr is None:\n",
    "            print(f\"skipping unreadable image: {src_path}\")\n",
    "            continue\n",
    "        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)\n",
    "        img = cv2.adaptiveThreshold(\n",
    "            img,\n",
    "            255,\n",
    "            cv2.ADAPTIVE_THRESH_MEAN_C,\n",
    "            cv2.THRESH_BINARY,\n",
    "            blockSize=5,\n",
    "            C=7,\n",
    "        )\n",
    "        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)\n",
    "        cv2.imwrite(line_path, img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the controlnet_aux processor named by processor_type over every\n",
    "# \"2_capture\" image and save the result next to it (existing files are\n",
    "# overwritten).\n",
    "processor_type = \"scribble_hed\"\n",
    "process = controlet_processor(processor_type)\n",
    "for validation_path in os.listdir(\"../validation_images\"):\n",
    "    validation_path = os.path.join(\"../validation_images\", validation_path)\n",
    "    # A stray file next to the sample folders would make the inner listdir raise.\n",
    "    if not os.path.isdir(validation_path):\n",
    "        continue\n",
    "    for image_path in os.listdir(validation_path):\n",
    "        if \"2_capture\" not in image_path:\n",
    "            continue\n",
    "        # Image.open keeps the file open until the data is loaded or the image\n",
    "        # is closed; the context manager releases the handle deterministically.\n",
    "        with Image.open(os.path.join(validation_path, image_path)) as source_image:\n",
    "            img = process(source_image, to_pil=True)\n",
    "        img.save(\n",
    "            os.path.join(validation_path, image_path.replace(\"capture\", processor_type))\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Centre one capture on a fully transparent square canvas whose side is the\n",
    "# image's longer edge, then overwrite the source file with the result.\n",
    "padding_image_path = \"/home/aihao/workspace/StableDiffusionReferenceOnly/validation_images/4/2_capture.png\"\n",
    "img = Image.open(padding_image_path).convert(\"RGBA\")\n",
    "width, height = img.size\n",
    "side_length = max(width, height)\n",
    "paste_origin = ((side_length - width) // 2, (side_length - height) // 2)\n",
    "filled_image = Image.new(\"RGBA\", (side_length, side_length), (0, 0, 0, 0))\n",
    "# The image is passed again as the third argument so it also acts as the paste mask.\n",
    "filled_image.paste(img, paste_origin, img)\n",
    "filled_image.save(padding_image_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the square-padded image held in filled_image.\n",
    "filled_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
